- Notifications
You must be signed in to change notification settings - Fork 251
/
Copy pathAnFany_DT_Classify.py
462 lines (396 loc) · 16.6 KB
/
AnFany_DT_Classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
# -*- coding:utf-8 -*-
# &Author AnFany
# CART分类树:可处理连续、离散的变量,支持多分类
# 测试数据和训练数据的字段顺序必须一样,因为本程序在设定规则按的是字段的编号,而不是名字
# 引入数据
importDT_Classify_Dataasdtda
importcopy
importpandasaspd
importnumpyasnp
# 定义函数
classDT:
def__init__(self, train_dtdata=dtda.dt_data, pre_dtdata=dtda.test_data, tree_length=4):
# 训练数据
self.train_dtdata=train_dtdata[0]['train']
# 验证数据
self.test_dtdata=train_dtdata[0]['test']
# 预测数据
self.pre_dtdata=pre_dtdata
# 中间过程变量
self.node_shujuji= {'0': self.train_dtdata} # 存储每一个节点数据集的字典
self.fenlei_shujuji= {'0': self.train_dtdata} # 存储需要分类的数据集的字典
# 叶子节点的集合
self.leafnodes= []
# 节点关系字典
self.noderela= {}
# 每一个节点的规则关系
self.node_rule= {'0': []}
# 避免树过大,采用限制书的深度以及基尼系数的值
self.tree_length=tree_length
# 根据类别的数组计算基尼指数
defjini_zhishu(self, exlist):
dnum=0
leng=len(exlist)
forhhinlist(set(exlist)):
dnum+= (list(exlist).count(hh) /leng) **2
return1-dnum
# 计算基尼系数的函数
defjini_xishu(self, tezheng, leibie): # 输入特征数据,类别数据,返回最小基尼系数对应的值
# 首先判断特征数据是连续、或者是分类的
sign=0
try:
tezheng[0] +2
# 证明是连续的
sign=1
exceptTypeError:
pass
ifsign: # 连续变量
# 去重、排序
quzhong=np.array(sorted(list(set(tezheng))))
# 判断是不是就一个值
iflen(quzhong) ==1:
returnFalse
# 取中间值
midd= (quzhong[:-1] +quzhong[1:]) /2
# 开始遍历每一个中间值,计算对应的基尼系数
length=len(leibie)
# 存储基尼系数的值
save_ji, jini=np.inf, 0
number=''
formiinmidd:
# 计算基尼系数
onelist=leibie[tezheng<=mi]
twolist=leibie[tezheng>mi]
jini= (len(onelist) /length) *self.jini_zhishu(onelist) + (len(twolist) /length) *self.jini_zhishu(twolist)
ifjini<=save_ji:
save_ji=jini
number=mi
returnnumber, save_ji
else: #分类变量
# 去重、排序
quzhong=np.array(list(set(tezheng)))
# 判断是不是就一个值
iflen(quzhong) ==1:
returnFalse
# 开始遍历每一个值,计算对应的基尼系数
length=len(leibie)
# 存储基尼系数的值
jini, save_ji=0, np.inf
number=''
formiinquzhong:
# 计算基尼系数
onelist=leibie[tezheng==mi]
twolist=leibie[tezheng!=mi]
jini= (len(onelist) /length) *self.jini_zhishu(onelist) + (len(twolist) /length) *self.jini_zhishu(
twolist)
ifjini<=save_ji:
save_ji=jini
number=mi
returnnumber, save_ji# 该特征最好的分割值,以及该特征最小的基尼系数
# 数据集确定分类特征以及属性的函数
deffeature_zhi(self, datadist): # 输入的数据集字典,输出最优的特征编号,以及对应的值,还有基尼系数
tezhengsign=''
number=np.inf
jini=''
forjilinrange(1, len(datadist[0])):
# 获取特征数据和类别数据
tezhen=datadist[:, (jil-1): jil].T[0]
leib=datadist[:, -1:].T[0]
# 在其中选择最小的
cresu=self.jini_xishu(tezhen, leib)
# 判断这个特征可不可取
ifcresu:
ifcresu[1] <=number:
number=cresu[1]
tezhengsign=jil-1
jini=cresu[0]
ifjini!='':
returntezhengsign, jini, number# 特征编号, 该特征最好的分割值,该数据集最小的基尼系数
else:
returnFalse# 这个数据集无法被分裂
# 将数据集合分裂
defdevided_shujuji(self, datadis): # 输入特征编号,对应的值,返回两个数据集
# 运算的结果
yuansuan=self.feature_zhi(datadis)
ifyuansuan:
# 需要判断这个被选中的特征是连续还是离散的
try:
datadis[:, yuansuan[0]][0] +2
oneshujui=datadis[datadis[:, yuansuan[0]] <=yuansuan[1]]
twoshujui=datadis[datadis[:, yuansuan[0]] >yuansuan[1]]
exceptTypeError:
oneshujui=datadis[datadis[:, yuansuan[0]] ==yuansuan[1]]
twoshujui=datadis[datadis[:, yuansuan[0]] !=yuansuan[1]]
returnoneshujui, twoshujui, yuansuan
else:
returnFalse
# 决策树函数
defgrow_tree(self):
whilelen(self.fenlei_shujuji) !=0:
# 需要复制字典
copy_dict=copy.deepcopy(self.fenlei_shujuji)
# 开始遍历每一个需要分类的数据集
forhdinself.fenlei_shujuji:
# 在这里限制树的深度
iflen(hd) ==self.tree_length+1:
# 不需要在分裂
delcopy_dict[hd]
# 添加到叶子节点的集合中
self.leafnodes.append(hd)
else:
fenguo=self.devided_shujuji(copy_dict[hd])
iffenguo:
iflen(set(fenguo[0][:, -1])) ==1: # 数据集是一个类别就不再分裂
self.leafnodes.append('%sl'%hd) # 成叶子节点
else:
copy_dict['%sl'%hd] =fenguo[0] # 继续分裂
self.node_shujuji['%sl'%hd] =fenguo[0] # 总的数据集
# 添加节点的规则
self.node_rule['%sl'%hd] = (self.node_rule[hd]).copy()
self.node_rule['%sl'%hd].append(fenguo[2])
iflen(set(fenguo[1][:, -1])) ==1:
self.leafnodes.append('%sr'%hd)
else:
copy_dict['%sr'%hd] =fenguo[1]
self.node_shujuji['%sr'%hd] =fenguo[1]
# 添加节点的规则
self.node_rule['%sr'%hd] = (self.node_rule[hd]).copy()
self.node_rule['%sr'%hd].append(fenguo[2])
# 添加到节点关系字典
self.noderela[hd] = ['%sl'%hd, '%sr'%hd]
delcopy_dict[hd] # 需要在分裂数据中删除这一个
self.fenlei_shujuji=copy.deepcopy(copy_dict)
print('所有节点的个数:', len(self.fenlei_shujuji))
print('需要分裂的数据集的个数:', len(self.node_shujuji))
return'done'
# 根据树得出每一个节点数据集的结果
defjieguo_tree(self):
# 根据每一个数据得到每一个节点对应的结果
shujuji_jieguo= {}
forshujuinself.node_shujuji:
zuihang=self.node_shujuji[shuju][:, -1]
# 选择最多的
duodict= {ik: list(zuihang).count(ik) forikinset(list(zuihang))}
# 在其中选择最多的
shujuji_jieguo[shuju] =max(duodict.items(), key=lambdadw: dw[1])[0]
returnshujuji_jieguo
# 要得到叶子节点的集合
defleafnodes_tree(self):
# 不在键值中的所有节点
keynodes=list(self.noderela.keys())
zhin=list(self.noderela.values())
zhinodes= []
forhhuinzhin:
forfffinhhu:
zhinodes.append(fff)
leafnodes= [jjforjjinzhinodesifjjnotinkeynodes]
returnleafnodes
# 寻找任何一个内部节点的叶子节点
definer_leaf(self, exnode):
# 内部节点
inernodes=list(self.noderela.keys())
# 叶子节点
llnodes= []
# 全部的节点
ghunodes=list(self.noderela.values())
gugu= []
forhhddinghunodes:
forghghinhhdd:
gugu.append(ghgh)
forjjingugu+ ['0']:
ifjjnotininernodes:
iflen(jj) >len(exnode) andexnodeinjj:
llnodes.append(jj)
returnllnodes
# 寻找任何一个内部节点的下属的节点
defxiashu_leaf(self, exnode):
# 叶子节点
xiashunodes= []
# 全部的节点
godes=list(self.noderela.values())
gug= []
forhhddingodes:
forghghinhhdd:
gug.append(ghgh)
forjjingug+ ['0']:
ifexnodeinjj:
xiashunodes.append(jj)
returnxiashunodes
# 判读数据是否符合这个规矩的函数
defjudge_data(self, data, signstr, guize):
# 首先判断数据连续或者是离散
fign=0
try:
data[guize[0]] +2
fign=1
exceptTypeError:
pass
iffign==1: # 连续
ifsignstr=='r':
ifdata[guize[0]] >guize[1]:
returnTrue
returnFalse
elifsignstr=='l':
ifdata[guize[0]] <=guize[1]:
returnTrue
returnFalse
eliffign==0: # 离散
ifsignstr=='r':
ifdata[guize[0]] !=guize[1]:
returnTrue
returnFalse
elifsignstr=='l':
ifdata[guize[0]] ==guize[1]:
returnTrue
returnFalse
# 预测函数, 根据节点的关系字典以及规则、每个节点的结果获得预测数据的结果
defpre_tree(self, predata):
# 每个数据集合的结果
meire=self.jieguo_tree()
# 存储结果
savresu= []
# 首先根据节点关系找到所有的叶子节点
yezinodes=self.leafnodes_tree()
# 开始判断数据
forjjinpredata:
shuju=jj[: -1]
# 开始判断
foryyinyezinodes:
gu=1
guide=self.node_rule[yy]
foriu, juinzip(yy[1:], guide):
ifnotself.judge_data(shuju, iu, ju):
gu=0
break
ifgu==1:
savresu.append(meire[yy])
returnsavresu
# 计算每一个节点的剪枝的基尼系数
defjianzhi_iner(self, exnode):
# 首先得到整体训练数据集的长度
leng=len(self.train_dtdata)
# # 在得到本节点数据集的长度,此项可以被消去
# benleng = len(self.node_shujuji[exnode])
# 计算被错误分类的数据的条数
self.node_result=self.jieguo_tree()
cuowu_leng=len(self.node_shujuji[exnode][self.node_shujuji[exnode][:, -1] !=self.node_result[exnode]])
# 计算
jinum=cuowu_leng/leng
returnjinum
# 计算每一个内部节点的下属叶子节点的基尼系数之和
definer_sum(self, ecnode):
jnum=0
# 首先得到这个内部节点下属的所有叶子节点
forhhhinself.iner_leaf(ecnode):
jnum+=self.jianzhi_iner(hhh)
returnjnum
# 树的剪枝, 每一棵树都是一个字典形式(节点关系就代表一棵子树)
defprue_tree(self):
# 开始剪枝
tree_set= {}
# a值的字典
adict= {}
# 第一棵完全生长的树
sign=0
tree_set[sign] =self.noderela.copy()
# 开始剪枝
whilelen(list(self.noderela.keys())) !=0:
# 复制字典
coppdict=self.noderela.copy()
# 存储内部节点剪枝基尼系数的字典
saveiner= {}
forjinerinlist(self.noderela.keys()):
# 每一个内部节点计算
saveiner[jiner] = (self.jianzhi_iner(jiner) -self.iner_sum(jiner)) / (len(self.iner_leaf(jiner)) -1)
# 选择其中最小的,如果有2个相同的选择最长的
numm=np.inf
dd=''
forhjiinsaveiner:
ifnumm>saveiner[hji]:
dd=hji
numm=saveiner[hji]
elifnumm==saveiner[hji]:
iflen(dd) <len(hji):
dd=hji
# 添加到a值
adict[sign] =numm
# 需要删除hji这个内部节点
# 首选得到这个内部节点所有的
forhcoinself.xiashu_leaf(dd):
ifhcoincoppdict:
delcoppdict[hco]
# 树加1
sign+=1
self.noderela=coppdict.copy()
tree_set[sign] =self.noderela.copy()
returntree_set, adict
# 计算正确率的函数
defcompuer_correct(self, exli_real, exli_pre):
iflen(exli_pre) ==0:
return0
else:
corr=np.array(exli_pre)[np.array(exli_pre) ==np.array(exli_real)]
returnlen(corr) /len(exli_pre)
# 交叉验证函数
defjiaocha_tree(self, treeset): #输出最终的树
# 正确率的字典
correct= {}
# 遍历树的集合
forjjintreeset:
self.noderela=treeset[jj]
yuce=self.pre_tree(self.test_dtdata)
# 真实的预测值
real=self.test_dtdata[:, -1]
# 计算正确率
correct[jj] =self.compuer_correct(real, yuce)
# 获得最大的,如果有相同的,获取数目最小的键
num=0
leys=''
forjjincorrect:
ifcorrect[jj] >num:
num=correct[jj]
leys=jj
elifnum==correct[jj]:
ifjj<leys:
leys=jj
returntreeset[leys], num
# 最终的函数
frompylabimportmpl
mpl.rcParams['font.sans-serif'] = ['FangSong'] # 显示中文
mpl.rcParams['axes.unicode_minus'] =False# 显示负号
importmatplotlib.pyplotasplt
# 根据不同的深度。看精确率的变化
if__name__=='__main__':
# 根据树的不同的初始深度,看正确率的变化
xunliande= []
yazhengde= []
yucede= []
forshenduinrange(2, 13):
uu=DT(tree_length=shendu)
# 完全成长的树
uu.grow_tree()
# 剪枝形成的树的集
gu=uu.prue_tree()
# 交叉验证形成的最好的树
cc=uu.jiaocha_tree(gu[0])
# 根据最好的树预测新的数据集的结果
uu.noderela=cc[0]
prenum=uu.pre_tree(uu.pre_dtdata)
# 验证的
yazhengde.append(cc[1])
# 预测的
yucede.append(uu.compuer_correct(uu.pre_dtdata[:, -1], prenum))
# 训练
trainnum=uu.pre_tree(uu.train_dtdata)
xunliande.append(uu.compuer_correct(uu.train_dtdata[:, -1], trainnum))
print(xunliande, yazhengde, yucede)
print('dddddddddddddddddddd', shendu)
# 绘制图
plt.plot(list(range(2, 13)), xunliande, 'o--', label='训练', lw=2)
plt.plot(list(range(2, 13)), yazhengde, '*--', label='验证', lw=2)
plt.plot(list(range(2, 13)), yucede, 's--', label='预测', lw=2)
plt.xlabel('树的初始深度')
plt.xlim(1, 14)
plt.ylabel('正确率')
plt.legend(shadow=True, fancybox=True)
plt.show()